#  Copyright (c) 2023, Apple Inc. All rights reserved.
#
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause


from coremltools import _logger as logger
from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.mil import Program
from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass
from coremltools.converters.mil.mil.passes.helper import block_context_manager
from coremltools.converters.mil.mil.passes.pass_registry import register_pass
from coremltools.converters.mil.mil.types.symbolic import any_variadic, is_symbolic, num_symbolic


@register_pass(namespace="common")
class remove_symbolic_reshape(AbstractGraphPass):
    """
    Rewrite ``reshape`` ops whose ``shape`` input holds symbols so that the
    ``shape`` becomes a constant made of plain integers.

    Note: No optimization is performed here. The new ``shape`` is read off the
    reshape's output shape: every concrete dimension is kept as its integer and
    the one remaining symbolic dimension (if any) is written as -1. A reshape
    is left untouched when its output shape has more than one symbolic
    dimension, or when its ``shape`` input has more than one child op.
    A ``ValueError`` is raised when the output shape is variadic.

    .. code-block::

        # Before remove_symbolic_reshape pass.
        main(%x: (s0, 4, fp32)) {
          block0() {
            %reshape_0_shape_0: (3,i32)^ = const(val=(s0, s1, 2))
            %reshape_0: (s0, 2, 2, fp32) = reshape(x=%x, shape=%reshape_0_shape_0)
          } -> (%reshape_0)
        }

        # After remove_symbolic_reshape pass.
        main(%x: (s0, 4, fp32)) {
          block0() {
            %reshape_0_shape_0x: (3,i32)* = const(val=[-1, 2, 2])
            %reshape_0: (-1, 2, 2, fp32) = reshape(x=%x, shape=%reshape_0_shape_0x)
          } -> (%reshape_0)
        }

    TODO (rdar://59165842): Use expand_dims, squeeze etc to use 0 instead of dynamic reshape with -1.
    """

    def apply(self, prog: Program):
        """Run the pass on every function of ``prog`` and log how many reshapes changed."""
        for func in prog.functions.values():
            changed = self._remove_symbolic_reshape_block(func)
            logger.info("remove_symbolic_reshape: changed {} reshapes.".format(changed))

    @staticmethod
    def _is_rewritable_reshape(op):
        """Return True if ``op`` is a reshape with a symbolic ``shape`` that has at most one child op."""
        if op.op_type != "reshape":
            return False
        shape_var = op.shape
        # `val` set: shape contains no symbol. `sym_val` unset: shape is runtime determined.
        return (
            shape_var.val is None
            and shape_var.sym_val is not None
            and len(shape_var.child_ops) <= 1
        )

    @block_context_manager
    def _remove_symbolic_reshape_block(self, block):
        """
        Process ``block`` and, recursively, the blocks nested inside its ops.

        Returns the number of reshape ops that were replaced. Raises
        ``ValueError`` if a candidate reshape has a variadic output shape.
        """
        replaced = 0
        # Iterate over a snapshot: ops are inserted into and removed from the block below.
        for op in list(block.operations):
            # Nested blocks are visited for every op, reshape or not.
            replaced += sum(self._remove_symbolic_reshape_block(inner) for inner in op.blocks)
            if not self._is_rewritable_reshape(op):
                continue

            # The output shape serves as the new `shape`.
            out_shape = op.outputs[0].shape
            if any_variadic(out_shape):
                raise ValueError(
                    (
                        "Cannot reshape to variadic from a compile time "
                        "shape argument. Variadic shape can only be achieved "
                        "via runtime shape argument. op: {}"
                    ).format(op)
                )
            if num_symbolic(out_shape) > 1:
                continue

            old_shape_var = op.shape
            # The single symbolic dimension (if present) turns into -1.
            new_shape_var = mb.const(
                val=[dim if not is_symbolic(dim) else -1 for dim in out_shape],
                name=old_shape_var.name + "x",
                before_op=op,
            )
            new_out = mb.reshape(x=op.x, shape=new_shape_var, name=op.name, before_op=op)
            op.enclosing_block.replace_uses_of_var_after_op(
                anchor_op=op, old_var=op.outputs[0], new_var=new_out
            )
            # Drop the old reshape together with the op that produced its shape, in one call.
            block.remove_ops([op, old_shape_var.op])
            replaced += 1
        return replaced
